{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "https://tvm.apache.ac.cn/docs/get_started/tutorials/ir_module.html\n",
        "\n",
        "conda_env:tvm-build"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n",
        "\n",
        "# IRModule\n",
        "This tutorial presents the core abstraction of Apache TVM Unity: the IRModule.\n",
        "An IRModule encompasses the **entirety** of an ML model, incorporating the\n",
        "computational graph, tensor programs, and potential calls to external libraries.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tvm\n",
        "from tvm import relax"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Create IRModule\n",
        "IRModules can be initialized in various ways. We demonstrate a few of them\n",
        "below.\n",
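        "\n",
        "As a baseline (a minimal sketch, not part of the original tutorial), an empty\n",
        "IRModule can also be constructed directly and populated later:\n",
        "\n",
        "```python\n",
        "import tvm\n",
        "\n",
        "mod = tvm.IRModule()          # an empty module with no functions yet\n",
        "print(mod.get_global_vars())  # -> []\n",
        "```\n",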
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "from tvm import relay, autotvm\n",
        "# Newer TVM releases use the exported-program frontend instead:\n",
        "# from torch.export import export\n",
        "# from tvm.relax.frontend.torch import from_exported_program"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Import from existing models\n",
        "The most common way to initialize an IRModule is to import an existing\n",
        "model. Apache TVM Unity supports importing from a range of frameworks,\n",
        "such as PyTorch and ONNX. This tutorial demonstrates only the PyTorch\n",
        "path, here via TorchScript tracing and the Relay frontend.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span>type List[A] {\n",
              "  Cons(A, List[A]),\n",
              "  Nil,\n",
              "}\n",
              "\n",
              "type Option[A] {\n",
              "  Some(A),\n",
              "  <span style=\"color: #008000; font-weight: bold\">None</span>,\n",
              "}\n",
              "\n",
              "type Tree[A] {\n",
              "  Rose(A, List[Tree[A]]),\n",
              "}\n",
              "\n",
              "type tensor_float16_t {\n",
              "  tensor_nil_float16,\n",
              "  tensor0_float16(float16),\n",
              "  tensor1_float16(Tensor[(?), float16]),\n",
              "  tensor2_float16(Tensor[(?, ?), float16]),\n",
              "  tensor3_float16(Tensor[(?, ?, ?), float16]),\n",
              "  tensor4_float16(Tensor[(?, ?, ?, ?), float16]),\n",
              "  tensor5_float16(Tensor[(?, ?, ?, ?, ?), float16]),\n",
              "  tensor6_float16(Tensor[(?, ?, ?, ?, ?, ?), float16]),\n",
              "}\n",
              "\n",
              "type tensor_float32_t {\n",
              "  tensor_nil_float32,\n",
              "  tensor0_float32(float32),\n",
              "  tensor1_float32(Tensor[(?), float32]),\n",
              "  tensor2_float32(Tensor[(?, ?), float32]),\n",
              "  tensor3_float32(Tensor[(?, ?, ?), float32]),\n",
              "  tensor4_float32(Tensor[(?, ?, ?, ?), float32]),\n",
              "  tensor5_float32(Tensor[(?, ?, ?, ?, ?), float32]),\n",
              "  tensor6_float32(Tensor[(?, ?, ?, ?, ?, ?), float32]),\n",
              "}\n",
              "\n",
              "type tensor_float64_t {\n",
              "  tensor_nil_float64,\n",
              "  tensor0_float64(float64),\n",
              "  tensor1_float64(Tensor[(?), float64]),\n",
              "  tensor2_float64(Tensor[(?, ?), float64]),\n",
              "  tensor3_float64(Tensor[(?, ?, ?), float64]),\n",
              "  tensor4_float64(Tensor[(?, ?, ?, ?), float64]),\n",
              "  tensor5_float64(Tensor[(?, ?, ?, ?, ?), float64]),\n",
              "  tensor6_float64(Tensor[(?, ?, ?, ?, ?, ?), float64]),\n",
              "}\n",
              "\n",
              "type tensor_int16_t {\n",
              "  tensor_nil_int16,\n",
              "  tensor0_int16(int16),\n",
              "  tensor1_int16(Tensor[(?), int16]),\n",
              "  tensor2_int16(Tensor[(?, ?), int16]),\n",
              "  tensor3_int16(Tensor[(?, ?, ?), int16]),\n",
              "  tensor4_int16(Tensor[(?, ?, ?, ?), int16]),\n",
              "  tensor5_int16(Tensor[(?, ?, ?, ?, ?), int16]),\n",
              "  tensor6_int16(Tensor[(?, ?, ?, ?, ?, ?), int16]),\n",
              "}\n",
              "\n",
              "type tensor_int32_t {\n",
              "  tensor_nil_int32,\n",
              "  tensor0_int32(int32),\n",
              "  tensor1_int32(Tensor[(?), int32]),\n",
              "  tensor2_int32(Tensor[(?, ?), int32]),\n",
              "  tensor3_int32(Tensor[(?, ?, ?), int32]),\n",
              "  tensor4_int32(Tensor[(?, ?, ?, ?), int32]),\n",
              "  tensor5_int32(Tensor[(?, ?, ?, ?, ?), int32]),\n",
              "  tensor6_int32(Tensor[(?, ?, ?, ?, ?, ?), int32]),\n",
              "}\n",
              "\n",
              "type tensor_int64_t {\n",
              "  tensor_nil_int64,\n",
              "  tensor0_int64(int64),\n",
              "  tensor1_int64(Tensor[(?), int64]),\n",
              "  tensor2_int64(Tensor[(?, ?), int64]),\n",
              "  tensor3_int64(Tensor[(?, ?, ?), int64]),\n",
              "  tensor4_int64(Tensor[(?, ?, ?, ?), int64]),\n",
              "  tensor5_int64(Tensor[(?, ?, ?, ?, ?), int64]),\n",
              "  tensor6_int64(Tensor[(?, ?, ?, ?, ?, ?), int64]),\n",
              "}\n",
              "\n",
              "type tensor_int8_t {\n",
              "  tensor_nil_int8,\n",
              "  tensor0_int8(int8),\n",
              "  tensor1_int8(Tensor[(?), int8]),\n",
              "  tensor2_int8(Tensor[(?, ?), int8]),\n",
              "  tensor3_int8(Tensor[(?, ?, ?), int8]),\n",
              "  tensor4_int8(Tensor[(?, ?, ?, ?), int8]),\n",
              "  tensor5_int8(Tensor[(?, ?, ?, ?, ?), int8]),\n",
              "  tensor6_int8(Tensor[(?, ?, ?, ?, ?, ?), int8]),\n",
              "}\n",
              "\n",
              "type tensor_uint16_t {\n",
              "  tensor_nil_uint16,\n",
              "  tensor0_uint16(uint16),\n",
              "  tensor1_uint16(Tensor[(?), uint16]),\n",
              "  tensor2_uint16(Tensor[(?, ?), uint16]),\n",
              "  tensor3_uint16(Tensor[(?, ?, ?), uint16]),\n",
              "  tensor4_uint16(Tensor[(?, ?, ?, ?), uint16]),\n",
              "  tensor5_uint16(Tensor[(?, ?, ?, ?, ?), uint16]),\n",
              "  tensor6_uint16(Tensor[(?, ?, ?, ?, ?, ?), uint16]),\n",
              "}\n",
              "\n",
              "type tensor_uint8_t {\n",
              "  tensor_nil_uint8,\n",
              "  tensor0_uint8(uint8),\n",
              "  tensor1_uint8(Tensor[(?), uint8]),\n",
              "  tensor2_uint8(Tensor[(?, ?), uint8]),\n",
              "  tensor3_uint8(Tensor[(?, ?, ?), uint8]),\n",
              "  tensor4_uint8(Tensor[(?, ?, ?, ?), uint8]),\n",
              "  tensor5_uint8(Tensor[(?, ?, ?, ?, ?), uint8]),\n",
              "  tensor6_uint8(Tensor[(?, ?, ?, ?, ?, ?), uint8]),\n",
              "}\n",
              "\n",
              "<span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #A2F\">@main</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>input0: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">784</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>input0:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>weight: Tensor[(<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">784</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>weight:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>bias: Tensor[(<span style=\"color: #008000\">256</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>bias:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_1<span style=\"color: #A2F; font-weight: bold\">.</span>weight: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">256</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span 
style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_1<span style=\"color: #A2F; font-weight: bold\">.</span>weight:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_1<span style=\"color: #A2F; font-weight: bold\">.</span>bias: Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_1<span style=\"color: #A2F; font-weight: bold\">.</span>bias:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>) {\n",
              "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">=</span> nn<span style=\"color: #A2F; font-weight: bold\">.</span>dense(<span style=\"color: #A2F; font-weight: bold\">%</span>input0, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>weight, units<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_0:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
              "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #A2F; font-weight: bold\">=</span> nn<span style=\"color: #A2F; font-weight: bold\">.</span>bias_add(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>bias, axis<span style=\"color: #A2F; font-weight: bold\">=-</span><span style=\"color: #008000\">1</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_0:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
              "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span> <span style=\"color: #A2F; font-weight: bold\">=</span> nn<span style=\"color: #A2F; font-weight: bold\">.</span>relu(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::relu_0:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
              "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">3</span> <span style=\"color: #A2F; font-weight: bold\">=</span> nn<span style=\"color: #A2F; font-weight: bold\">.</span>dense(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_1<span style=\"color: #A2F; font-weight: bold\">.</span>weight, units<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_1:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
              "  nn<span style=\"color: #A2F; font-weight: bold\">.</span>bias_add(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">3</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_1<span style=\"color: #A2F; font-weight: bold\">.</span>bias, axis<span style=\"color: #A2F; font-weight: bold\">=-</span><span style=\"color: #008000\">1</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_1:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>\n",
              "}\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Create a dummy model\n",
        "class TorchModel(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(TorchModel, self).__init__()\n",
        "        self.fc1 = nn.Linear(784, 256)\n",
        "        self.relu1 = nn.ReLU()\n",
        "        self.fc2 = nn.Linear(256, 10)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.fc1(x)\n",
        "        x = self.relu1(x)\n",
        "        x = self.fc2(x)\n",
        "        return x\n",
        "\n",
        "model = TorchModel()\n",
        "input_shape = [1, 784]\n",
        "input_data = torch.randn(input_shape)\n",
        "scripted_model = torch.jit.trace(model, input_data).eval()\n",
        "\n",
        "input_name = \"input0\"\n",
        "shape_list = [(input_name, input_shape)]\n",
        "mod_from_torch, params = relay.frontend.from_pytorch(scripted_model, shape_list)\n",
        "mod_from_torch.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Write with Relax NN Module\n",
        "Apache TVM Unity also provides a set of PyTorch-like APIs that help users\n",
        "write an IRModule directly.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
              "\n",
              "<span style=\"color: #A2F\">@I</span><span style=\"color: #A2F; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #00F; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #A2F\">@R</span><span style=\"color: #A2F; font-weight: bold\">.</span>function\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">forward</span>(x: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">784</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">784</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_bias: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>,), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_bias: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>,), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
              "        R<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;num_input&quot;</span>: <span style=\"color: #008000\">1</span>})\n",
              "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>dataflow():\n",
              "            permute_dims: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">784</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>permute_dims(fc1_weight, axes<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>)\n",
              "            matmul: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>matmul(x, permute_dims, out_dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
              "            add: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(matmul, fc1_bias)\n",
              "            relu: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>nn<span style=\"color: #A2F; font-weight: bold\">.</span>relu(add)\n",
              "            permute_dims1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>permute_dims(fc2_weight, axes<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>)\n",
              "            matmul1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>matmul(relu, permute_dims1, out_dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
              "            add1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(matmul1, fc2_bias)\n",
              "            gv: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> add1\n",
              "            R<span style=\"color: #A2F; font-weight: bold\">.</span>output(gv)\n",
              "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from tvm.relax.frontend import nn\n",
        "\n",
        "class RelaxModel(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(RelaxModel, self).__init__()\n",
        "        self.fc1 = nn.Linear(784, 256)\n",
        "        self.relu1 = nn.ReLU()\n",
        "        self.fc2 = nn.Linear(256, 10)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.fc1(x)\n",
        "        x = self.relu1(x)\n",
        "        x = self.fc2(x)\n",
        "        return x\n",
        "\n",
        "mod_from_relax, params_from_relax = RelaxModel().export_tvm(\n",
        "    {\"forward\": {\"x\": nn.spec.Tensor((1, 784), \"float32\")}}\n",
        ")\n",
        "mod_from_relax.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Create via TVMScript\n",
        "TVMScript is a Python-based DSL for IRModules. We can print an IRModule\n",
        "directly in TVMScript syntax, or, conversely, parse TVMScript source to\n",
        "obtain an IRModule.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
              "\n",
              "<span style=\"color: #A2F\">@I</span><span style=\"color: #A2F; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #00F; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #A2F\">@R</span><span style=\"color: #A2F; font-weight: bold\">.</span>function\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">main</span>(x: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">784</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">784</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_bias: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>,), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_bias: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>,), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
              "        R<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;num_input&quot;</span>: <span style=\"color: #008000\">1</span>})\n",
              "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>dataflow():\n",
              "            permute_dims: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">784</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>permute_dims(fc1_weight, axes<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>)\n",
              "            matmul: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>matmul(x, permute_dims, out_dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
              "            add: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(matmul, fc1_bias)\n",
              "            relu: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>nn<span style=\"color: #A2F; font-weight: bold\">.</span>relu(add)\n",
              "            permute_dims1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>permute_dims(fc2_weight, axes<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>)\n",
              "            matmul1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>matmul(relu, permute_dims1, out_dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
              "            add1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(matmul1, fc2_bias)\n",
              "            gv: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> add1\n",
              "            R<span style=\"color: #A2F; font-weight: bold\">.</span>output(gv)\n",
              "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from tvm.script import ir as I\n",
        "from tvm.script import relax as R\n",
        "\n",
        "@I.ir_module\n",
        "class TVMScriptModule:\n",
        "    @R.function\n",
        "    def main(\n",
        "        x: R.Tensor((1, 784), dtype=\"float32\"),\n",
        "        fc1_weight: R.Tensor((256, 784), dtype=\"float32\"),\n",
        "        fc1_bias: R.Tensor((256,), dtype=\"float32\"),\n",
        "        fc2_weight: R.Tensor((10, 256), dtype=\"float32\"),\n",
        "        fc2_bias: R.Tensor((10,), dtype=\"float32\"),\n",
        "    ) -> R.Tensor((1, 10), dtype=\"float32\"):\n",
        "        R.func_attr({\"num_input\": 1})\n",
        "        with R.dataflow():\n",
        "            permute_dims = R.permute_dims(fc1_weight, axes=None)\n",
        "            matmul = R.matmul(x, permute_dims, out_dtype=\"void\")\n",
        "            add = R.add(matmul, fc1_bias)\n",
        "            relu = R.nn.relu(add)\n",
        "            permute_dims1 = R.permute_dims(fc2_weight, axes=None)\n",
        "            matmul1 = R.matmul(relu, permute_dims1, out_dtype=\"void\")\n",
        "            add1 = R.add(matmul1, fc2_bias)\n",
        "            gv = add1\n",
        "            R.output(gv)\n",
        "        return gv\n",
        "\n",
        "mod_from_script = TVMScriptModule\n",
        "mod_from_script.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Attributes of an IRModule\n",
        "An IRModule is a collection of functions, indexed by GlobalVars.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[I.GlobalVar(\"main\")]\n"
          ]
        }
      ],
      "source": [
        "mod = mod_from_torch\n",
        "print(mod.get_global_vars())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can access the functions in the IRModule by indexing with the GlobalVars\n",
        "or their names.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "fn (%input0: Tensor[(1, 784), float32] /* span=aten::linear_0.input0:0:0 */, %aten::linear_0.weight: Tensor[(256, 784), float32] /* span=aten::linear_0.weight:0:0 */, %aten::linear_0.bias: Tensor[(256), float32] /* span=aten::linear_0.bias:0:0 */, %aten::linear_1.weight: Tensor[(10, 256), float32] /* span=aten::linear_1.weight:0:0 */, %aten::linear_1.bias: Tensor[(10), float32] /* span=aten::linear_1.bias:0:0 */) {\n",
            "  %0 = nn.dense(%input0, %aten::linear_0.weight, units=None) /* span=aten::linear_0:0:0 */;\n",
            "  %1 = nn.bias_add(%0, %aten::linear_0.bias, axis=-1) /* span=aten::linear_0:0:0 */;\n",
            "  %2 = nn.relu(%1) /* span=aten::relu_0:0:0 */;\n",
            "  %3 = nn.dense(%2, %aten::linear_1.weight, units=None) /* span=aten::linear_1:0:0 */;\n",
            "  nn.bias_add(%3, %aten::linear_1.bias, axis=-1) /* span=aten::linear_1:0:0 */\n",
            "}\n"
          ]
        }
      ],
      "source": [
        "# index by the global var's name\n",
        "print(mod[\"main\"])\n",
        "# index by the GlobalVar itself, and check that both give the same function\n",
        "(gv,) = mod.get_global_vars()\n",
        "assert mod[gv] == mod[\"main\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Transformations on IRModules\n",
        "Transformations are a key component of Apache TVM Unity. A transformation\n",
        "takes in an IRModule and outputs another IRModule. Applying a sequence of\n",
        "transformations to an IRModule is the common way to optimize a model.\n",
        "\n",
        "In this getting-started tutorial, we only demonstrate how to apply transformations\n",
        "to an IRModule. For details of each transformation, please refer to the\n",
        "`Transformation API Reference <api-relax-transformation>`.\n",
        "\n"
      ]
    },
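The module-to-module contract can be sketched in plain Python (a conceptual toy under stated assumptions — `legalize`, `compose`, and the dict-based "module" are illustrative stand-ins, not TVM APIs): each pass maps a module to a new module, so a pipeline is just function composition.

```python
from functools import reduce

# A toy "module": a mapping from global function names to function bodies
# (plain strings here; IR functions in real TVM).
toy_module = {"main": "matmul -> add -> relu"}

def legalize(module):
    # Toy pass: split each fused description into separate "functions".
    # It returns a NEW module and leaves the input untouched, mirroring
    # how a TVM pass produces a fresh IRModule.
    out = dict(module)
    for name, body in module.items():
        for op in body.split(" -> "):
            out[f"{name}_{op}"] = op
    return out

def compose(*passes):
    # A pipeline is just function composition over modules.
    return lambda module: reduce(lambda m, p: p(m), passes, module)

pipeline = compose(legalize)
new_module = pipeline(toy_module)
```

After the pipeline runs, `new_module` holds the original `main` plus one entry per lowered op, while `toy_module` is unchanged — the same shape of behavior you see when applying real transformations below.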
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We first apply the **LegalizeOps** transformation to the IRModule. This transformation\n",
        "converts the Relax module into a mixed stage, with both Relax and TensorIR functions\n",
        "within the same module. Meanwhile, the Relax operators are converted into ``call_tir`` calls.\n",
        "\n"
      ]
    },
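For intuition, ``call_tir`` follows a destination-passing convention: the caller allocates the output buffer and the low-level function writes into it. A minimal NumPy sketch of that convention (illustrative only — `tir_add` and this `call_tir` helper are toy stand-ins, not the TVM API):

```python
import numpy as np

def tir_add(a, b, out):
    # Destination-passing style: the callee fills a caller-allocated
    # buffer instead of returning a fresh one.
    np.add(a, b, out=out)

def call_tir(func, args, out_shape, dtype="float32"):
    # Mimics the call_tir convention: allocate the destination, invoke
    # the low-level function, then hand the buffer back as the result.
    out = np.empty(out_shape, dtype=dtype)
    func(*args, out)
    return out

x = np.ones((1, 10), dtype="float32")
y = np.full((1, 10), 2.0, dtype="float32")
z = call_tir(tir_add, (x, y), (1, 10))
```

Separating allocation from computation this way is what lets the compiler plan memory for the whole graph instead of leaving each operator to allocate its own outputs.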
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span>\n",
              "<span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #A2F\">@main</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>input0: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">784</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>input0:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>weight: Tensor[(<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">784</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>weight:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>bias: Tensor[(<span style=\"color: #008000\">256</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>bias:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_1<span style=\"color: #A2F; font-weight: bold\">.</span>weight: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">256</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span 
style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_1<span style=\"color: #A2F; font-weight: bold\">.</span>weight:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_1<span style=\"color: #A2F; font-weight: bold\">.</span>bias: Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_1<span style=\"color: #A2F; font-weight: bold\">.</span>bias:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>) {\n",
              "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">=</span> nn<span style=\"color: #A2F; font-weight: bold\">.</span>dense(<span style=\"color: #A2F; font-weight: bold\">%</span>input0, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>weight, units<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_0:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
              "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #A2F; font-weight: bold\">=</span> nn<span style=\"color: #A2F; font-weight: bold\">.</span>bias_add(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_0<span style=\"color: #A2F; font-weight: bold\">.</span>bias, axis<span style=\"color: #A2F; font-weight: bold\">=-</span><span style=\"color: #008000\">1</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_0:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
              "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span> <span style=\"color: #A2F; font-weight: bold\">=</span> nn<span style=\"color: #A2F; font-weight: bold\">.</span>relu(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::relu_0:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
              "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">3</span> <span style=\"color: #A2F; font-weight: bold\">=</span> nn<span style=\"color: #A2F; font-weight: bold\">.</span>dense(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_1<span style=\"color: #A2F; font-weight: bold\">.</span>weight, units<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_1:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
              "  nn<span style=\"color: #A2F; font-weight: bold\">.</span>bias_add(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">3</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>aten::linear_1<span style=\"color: #A2F; font-weight: bold\">.</span>bias, axis<span style=\"color: #A2F; font-weight: bold\">=-</span><span style=\"color: #008000\">1</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> span<span style=\"color: #A2F; font-weight: bold\">=</span>aten::linear_1:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">*/</span>\n",
              "}\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "mod = mod_from_torch\n",
        "mod = relax.transform.LegalizeOps()(mod)\n",
        "mod.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "After the transformation, there are many more functions inside the module. Let's print\n",
        "the global vars again.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[I.GlobalVar(\"main\")]\n"
          ]
        }
      ],
      "source": [
        "print(mod.get_global_vars())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next, Apache TVM Unity provides a set of default transformation pipelines\n",
        "to simplify the transformation process. We can apply a default pipeline to the module.\n",
        "The default **zero** pipeline contains the following fundamental transformations:\n",
        "\n",
        "- **LegalizeOps**: This transform converts the Relax operators into `call_tir` functions\n",
        "  with the corresponding TensorIR Functions. After this transform, the IRModule will\n",
        "  contain both Relax functions and TensorIR functions.\n",
        "- **AnnotateTIROpPattern**: This transform annotates the pattern of the TensorIR functions,\n",
        "  preparing them for subsequent operator fusion.\n",
        "- **FoldConstant**: This pass performs constant folding, optimizing operations\n",
        "  involving constants.\n",
        "- **FuseOps and FuseTIR**: These two passes work together to fuse operators based on the\n",
        "  patterns annotated in the previous step (AnnotateTIROpPattern). These passes transform\n",
        "  both Relax functions and TensorIR functions.\n",
        "\n",
        "<div class=\"alert alert-info\"><h4>Note</h4><p>Here, we have applied **LegalizeOps** twice in the flow. The second application is redundant but\n",
        "  harmless.\n",
        "\n",
        "  Any pass can appear more than once in the flow, since every pass is guaranteed to handle all legal\n",
        "  IRModule inputs. This design helps users construct their own pipelines.</p></div>\n",
        "\n"
      ]
    },
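The duplication guarantee from the note above can be illustrated with a toy idempotent pass (plain Python with hypothetical names — not the TVM API): a pass that accepts any legal module leaves already-processed entries unchanged, so running it a second time is harmless.

```python
def lower_ops(module):
    # Toy "legalization": lower high-level op names to low-level ones.
    # Names not in the table (including already-lowered ones) pass
    # through unchanged, so the pass is idempotent.
    table = {"matmul": "tir_matmul", "add": "tir_add", "relu": "tir_relu"}
    return {name: table.get(op, op) for name, op in module.items()}

module = {"f0": "matmul", "f1": "relu"}
once = lower_ops(module)
twice = lower_ops(lower_ops(module))
assert once == twice  # applying the pass a second time changes nothing
```

Because each pass tolerates any legal input, users can freely reorder, repeat, or interleave passes when building custom pipelines.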
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
              "\n",
              "<span style=\"color: #A2F\">@I</span><span style=\"color: #A2F; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #00F; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #A2F\">@T</span><span style=\"color: #A2F; font-weight: bold\">.</span>prim_func(private<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">fused_matmul1_add1</span>(relu: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), permute_dims1: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_bias: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>),), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), T_add_intermediate: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
              "        T<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #A2F; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
              "        matmul_intermediate <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>alloc_buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>)))\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i0, i1, k <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;matmul&quot;</span>):\n",
              "                v_i0, v_i1, v_k <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SSR&quot;</span>, [i0, i1, k])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(matmul_intermediate[v_i0, v_i1])\n",
              "                <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>init():\n",
              "                    matmul_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>)\n",
              "                matmul_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">=</span> matmul_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">+</span> relu[v_i0, v_k] <span style=\"color: #A2F; font-weight: bold\">*</span> permute_dims1[v_k, v_i1]\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> ax0, ax1 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>)):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;T_add&quot;</span>):\n",
              "                v_ax0, v_ax1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [ax0, ax1])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(matmul_intermediate[v_ax0, v_ax1], fc2_bias[v_ax1])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(T_add_intermediate[v_ax0, v_ax1])\n",
              "                T_add_intermediate[v_ax0, v_ax1] <span style=\"color: #A2F; font-weight: bold\">=</span> matmul_intermediate[v_ax0, v_ax1] <span style=\"color: #A2F; font-weight: bold\">+</span> fc2_bias[v_ax1]\n",
              "\n",
              "    <span style=\"color: #A2F\">@T</span><span style=\"color: #A2F; font-weight: bold\">.</span>prim_func(private<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">fused_matmul_add_relu</span>(x: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">784</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), permute_dims: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">784</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_bias: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>),), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), compute_intermediate: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
              "        T<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #A2F; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
              "        matmul_intermediate <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>alloc_buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)))\n",
              "        T_add_intermediate <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>alloc_buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)))\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i0, i1, k <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">784</span>)):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;matmul&quot;</span>):\n",
              "                v_i0, v_i1, v_k <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SSR&quot;</span>, [i0, i1, k])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(x[v_i0, v_k], permute_dims[v_k, v_i1])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(matmul_intermediate[v_i0, v_i1])\n",
              "                <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>init():\n",
              "                    matmul_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>)\n",
              "                matmul_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">=</span> matmul_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">+</span> x[v_i0, v_k] <span style=\"color: #A2F; font-weight: bold\">*</span> permute_dims[v_k, v_i1]\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> ax0, ax1 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;T_add&quot;</span>):\n",
              "                v_ax0, v_ax1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [ax0, ax1])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(matmul_intermediate[v_ax0, v_ax1], fc1_bias[v_ax1])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(T_add_intermediate[v_ax0, v_ax1])\n",
              "                T_add_intermediate[v_ax0, v_ax1] <span style=\"color: #A2F; font-weight: bold\">=</span> matmul_intermediate[v_ax0, v_ax1] <span style=\"color: #A2F; font-weight: bold\">+</span> fc1_bias[v_ax1]\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i0, i1 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;compute&quot;</span>):\n",
              "                v_i0, v_i1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i0, i1])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(T_add_intermediate[v_i0, v_i1])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(compute_intermediate[v_i0, v_i1])\n",
              "                compute_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>max(T_add_intermediate[v_i0, v_i1], T<span style=\"color: #A2F; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>))\n",
              "\n",
              "    <span style=\"color: #A2F\">@T</span><span style=\"color: #A2F; font-weight: bold\">.</span>prim_func(private<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">transpose</span>(fc1_weight: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">784</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), T_transpose: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">784</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
              "        T<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;op_pattern&quot;</span>: <span style=\"color: #008000\">2</span>, <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #A2F; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> ax0, ax1 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">784</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;T_transpose&quot;</span>):\n",
              "                v_ax0, v_ax1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [ax0, ax1])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(fc1_weight[v_ax1, v_ax0])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(T_transpose[v_ax0, v_ax1])\n",
              "                T_transpose[v_ax0, v_ax1] <span style=\"color: #A2F; font-weight: bold\">=</span> fc1_weight[v_ax1, v_ax0]\n",
              "\n",
              "    <span style=\"color: #A2F\">@T</span><span style=\"color: #A2F; font-weight: bold\">.</span>prim_func(private<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">transpose1</span>(fc2_weight: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), T_transpose: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
              "        T<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;op_pattern&quot;</span>: <span style=\"color: #008000\">2</span>, <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #A2F; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> ax0, ax1 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>)):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;T_transpose&quot;</span>):\n",
              "                v_ax0, v_ax1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [ax0, ax1])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(fc2_weight[v_ax1, v_ax0])\n",
              "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(T_transpose[v_ax0, v_ax1])\n",
              "                T_transpose[v_ax0, v_ax1] <span style=\"color: #A2F; font-weight: bold\">=</span> fc2_weight[v_ax1, v_ax0]\n",
              "\n",
              "    <span style=\"color: #A2F\">@R</span><span style=\"color: #A2F; font-weight: bold\">.</span>function\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">forward</span>(x: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">784</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">784</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_bias: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>,), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_bias: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>,), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
              "        R<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;num_input&quot;</span>: <span style=\"color: #008000\">1</span>})\n",
              "        cls <span style=\"color: #A2F; font-weight: bold\">=</span> Module\n",
              "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>dataflow():\n",
              "            permute_dims <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #A2F; font-weight: bold\">.</span>transpose, (fc1_weight,), out_sinfo<span style=\"color: #A2F; font-weight: bold\">=</span>R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">784</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
              "            lv <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #A2F; font-weight: bold\">.</span>fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_sinfo<span style=\"color: #A2F; font-weight: bold\">=</span>R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
              "            permute_dims1 <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #A2F; font-weight: bold\">.</span>transpose1, (fc2_weight,), out_sinfo<span style=\"color: #A2F; font-weight: bold\">=</span>R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
              "            gv <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #A2F; font-weight: bold\">.</span>fused_matmul1_add1, (lv, permute_dims1, fc2_bias), out_sinfo<span style=\"color: #A2F; font-weight: bold\">=</span>R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
              "            R<span style=\"color: #A2F; font-weight: bold\">.</span>output(gv)\n",
              "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "mod, param_spec = RelaxModel().export_tvm(\n",
        "    spec={\"forward\": {\"x\": nn.spec.Tensor((1, 784), \"float32\")}}\n",
        ")\n",
        "# mod.show()\n",
        "mod = relax.get_pipeline(\"zero\")(mod)\n",
        "mod.show()\n",
        "# print(mod.get_global_vars())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Deploy the IRModule Universally\n",
        "After the optimization, we can compile the model into a TVM runtime module.\n",
        "Notably, Apache TVM Unity provides the ability of universal deployment, which means\n",
        "we can deploy the same IRModule on different backends, including CPU, GPU, and other emerging\n",
        "backends.\n",
        "\n",
        "### Deploy on CPU\n",
        "We can deploy the IRModule on CPU by specifying the target as ``llvm``.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[[25061.188 24633.898 26877.295 23512.16  24645.11  24105.252 24847.125\n",
            "  23395.107 23377.137 26118.697]]\n"
          ]
        }
      ],
      "source": [
        "exec = relax.build(mod, target=\"llvm\")\n",
        "dev = tvm.cpu()\n",
        "vm = relax.VirtualMachine(exec, dev)\n",
        "\n",
        "raw_data = np.random.rand(1, 784).astype(\"float32\")\n",
        "data = tvm.nd.array(raw_data, dev)\n",
        "\n",
        "device = tvm.cpu()\n",
        "param_spec = params_from_relax\n",
        "params = [np.random.rand(*param.shape).astype(\"float32\") for _, param in param_spec]\n",
        "params = [tvm.nd.array(param, device=device) for param in params]\n",
        "\n",
        "cpu_out = vm[\"forward\"](data, *params).numpy()\n",
        "print(cpu_out)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Deploy on GPU\n",
        "Besides, CPU backend, we can also deploy the IRModule on GPU. GPU requires\n",
        "programs containing extra information, such as the thread bindings and shared memory\n",
        "allocations. We need a further transformation to generate the GPU programs.\n",
        "\n",
        "We use ``DLight`` to generate the GPU programs. In this tutorial, we won't go into\n",
        "the details of ``DLight``.\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from tvm import dlight as dl\n",
        "\n",
        "with tvm.target.Target(\"cuda\"):\n",
        "    gpu_mod = dl.ApplyDefaultSchedule(\n",
        "        dl.gpu.Matmul(),\n",
        "        dl.gpu.Fallback(),\n",
        "    )(mod)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we can compile the IRModule on GPU, the similar way as we did on CPU.\n",
        "\n"
      ]
    },
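    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sketch (assuming a CUDA-capable device, and that ``gpu_mod``, ``raw_data``, and ``param_spec`` from the earlier cells are in scope), the GPU build mirrors the CPU one:\n",
        "\n",
        "```python\n",
        "exec = relax.build(gpu_mod, target=\"cuda\")\n",
        "dev = tvm.device(\"cuda\", 0)\n",
        "vm = relax.VirtualMachine(exec, dev)\n",
        "\n",
        "# Inputs and parameters must live on the GPU device.\n",
        "data = tvm.nd.array(raw_data, dev)\n",
        "gpu_params = [\n",
        "    tvm.nd.array(np.random.rand(*param.shape).astype(\"float32\"), dev)\n",
        "    for _, param in param_spec\n",
        "]\n",
        "\n",
        "gpu_out = vm[\"forward\"](data, *gpu_params).numpy()\n",
        "print(gpu_out)\n",
        "```\n"
      ]
    },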
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Deploy on Other Backends\n",
        "Apache TVM Unity also supports other backends, such as different kinds of GPUs\n",
        "(Metal, ROCm, Vulkan and OpenCL), different kinds of CPUs (x86, ARM), and other\n",
        "emerging backends (e.g., WebAssembly). The deployment process is similar to the\n",
        "GPU backend.\n",
        "\n"
      ]
    }
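    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sketch (assuming the corresponding runtime is available, e.g. Metal on macOS), only the target string changes; the rest of the flow is identical:\n",
        "\n",
        "```python\n",
        "# Schedule for the Metal GPU backend, then build and run as before.\n",
        "with tvm.target.Target(\"metal\"):\n",
        "    metal_mod = dl.ApplyDefaultSchedule(\n",
        "        dl.gpu.Matmul(),\n",
        "        dl.gpu.Fallback(),\n",
        "    )(mod)\n",
        "exec = relax.build(metal_mod, target=\"metal\")\n",
        "vm = relax.VirtualMachine(exec, tvm.metal())\n",
        "```\n"
      ]
    }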
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "tvm-build",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
